from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader

import torch
import argparse
import tqdm
import os
import numpy as np



class Dataset(object):
    """Map-style dataset yielding ``(image_tensor, path)`` pairs.

    Each entry of ``files`` is opened with PIL, converted to RGB, turned into a
    tensor, normalised with the standard ImageNet mean/std and resized to
    400x400.
    """

    def __init__(self, files):
        # The caller's sequence is kept as-is; items are looked up by index.
        self.dataset = files
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
            transforms.Resize((400, 400)),
        ]
        self.preprocess = transforms.Compose(steps)

    def __getitem__(self, idx):
        """Return the preprocessed image at ``idx`` together with its path."""
        path = self.dataset[idx]
        rgb = Image.open(path).convert('RGB')
        return self.preprocess(rgb), path

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.dataset)

def touch(fname, times=None):
    """Create ``fname`` if it does not exist and update its timestamps.

    Opening in append mode never truncates, so an existing file keeps its
    contents. ``times`` is forwarded to ``os.utime``: ``None`` means "now";
    otherwise it is an ``(atime, mtime)`` pair.
    """
    handle = open(fname, 'a')
    try:
        os.utime(fname, times)
    finally:
        handle.close()


# Command-line interface: a text file listing the images to segment, and an
# output prefix under which per-image results are written.
parser = argparse.ArgumentParser(
    description="PyTorch Semantic Segmentation Testing"
)
parser.add_argument(
    "--img_list",
    default='',
    type=str,
    help="text file listing the images to process, one path per line",
)
parser.add_argument(
    "--save_folder",
    default='',
    type=str,
    help="output prefix for results; it is prepended verbatim to each image "
         "path, so include a trailing '/'",
)
args = parser.parse_args()

# The image list is mandatory. Without it `files` would never be bound and the
# DataLoader construction further down would fail with an opaque NameError.
if args.img_list == '':
    parser.error("--img_list is required (a text file with one image path per line)")
with open(args.img_list) as f:
    # Skip blank lines; an empty path would make Image.open fail mid-run.
    files = [line for line in f.read().splitlines() if line.strip()]

model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True)
model.eval()

# Keep the model and every input batch on the same device: CUDA when present,
# otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model.to(device)
# Pinned host memory only helps host-to-GPU copies.
dl = DataLoader(Dataset(files), batch_size=16, pin_memory=use_cuda, num_workers=4)

# Label index of "person" in the 21-class Pascal VOC label set predicted by
# torchvision's pretrained segmentation models.
PERSON_CLASS = 15
# An image counts as containing a person when more than this many of its
# pixels are labelled PERSON_CLASS.
MIN_PERSON_PIXELS = 10


for input_batch, paths in tqdm.tqdm(dl, total=len(dl)):
    input_batch = input_batch.to(device)
    with torch.no_grad():
        # 'out' holds per-class scores shaped (batch, classes, H, W). Taking the
        # argmax over the class axis yields one (H, W) label map per image, so
        # every image in the batch gets its own prediction.
        output_predictions = model(input_batch)['out'].argmax(1).cpu()
    for file, p in zip(paths, output_predictions):
        # Per-class pixel counts; np.unique returns the classes sorted.
        classes, scores = np.unique(p.numpy().ravel(), return_counts=True)
        person_pixels = int(scores[classes == PERSON_CLASS].sum())
        has_person = 1 if person_pixels > MIN_PERSON_PIXELS else 0

        # splitext strips only the final extension, so dots in directory names
        # (e.g. "./data/x.jpg") are preserved.
        save_name = args.save_folder + os.path.splitext(file)[0] + ".txt"
        dirname = os.path.dirname(save_name)
        if dirname:  # a bare filename has no directory to create
            os.makedirs(dirname, exist_ok=True)
        with open(save_name, "w") as fd:
            fd.write(str(has_person) + " \n")
            fd.writelines(
                str(c) + ":" + str(s) + " \n"
                for c, s in zip(classes.tolist(), scores.tolist())
            )

        # Mirror the image path under person/ or noperson/ as an empty marker
        # file. Stripping the leading separator keeps absolute image paths
        # inside save_folder instead of letting os.path.join discard it.
        marker = os.path.join(
            args.save_folder,
            'person' if has_person else 'noperson',
            file.lstrip(os.sep),
        )
        os.makedirs(os.path.dirname(marker), exist_ok=True)
        touch(marker)